Adding UniSRec model implemented on lightweight class hierarchy with pytorch preprocessing #306
TOPAPEC wants to merge 7 commits into MTSWebServices:main
Conversation
Standalone sequential recommender package that mimics the `ModelBase` interface without touching existing RecTools code. FlatSASRec — a plain ID-embedding SASRec encoder. UniSRec — pretrained text embeddings + PCA/BN adaptor with 3-phase training (ID embeddings -> adaptor only -> full finetune). Uses a lightweight `rank_topk` instead of `TorchRanker` and reuses `SASRecDataPreparator` for the data pipeline. 30 tests, smoke scripts for both models. Fix: NaN * 0 = NaN in IEEE 754 broke attention padding masking via multiplication; switched to `masked_fill`.
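The masking fix above can be reproduced with plain IEEE 754 arithmetic: multiplying a padded score by 0 does not clear it if the score is already NaN, whereas explicitly overwriting masked positions (what torch's `masked_fill` does) always works. A minimal torch-free sketch (function names here are illustrative):

```python
import math

def mask_by_multiply(scores, mask):
    # Broken approach: NaN * 0.0 == NaN in IEEE 754, so padding is not cleared.
    return [s * m for s, m in zip(scores, mask)]

def mask_by_fill(scores, mask, fill=float("-inf")):
    # masked_fill-style approach: overwrite masked positions unconditionally.
    return [s if m else fill for s, m in zip(scores, mask)]

scores = [1.5, float("nan"), 2.0]
mask = [1.0, 0.0, 1.0]  # 0.0 marks a padding position

mult = mask_by_multiply(scores, mask)
fill = mask_by_fill(scores, [bool(m) for m in mask])

assert math.isnan(mult[1])       # multiplication leaks NaN through the padding
assert fill[1] == float("-inf")  # filling removes it regardless of the value
```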
New config options:
- `ffn_type`: conv1d / linear_gelu / linear_relu, plus `ffn_expansion`
- `optimizer`: adam / adamw
- `scheduler`: cosine_warmup (with `warmup_ratio`, `min_lr_ratio`)
- `loss`: softmax / BCE / gBCE / sampled_softmax (with `gbce_t`)
- `patience`: early stopping via an EarlyStopping callback + val split
- `data_preparator`: accept a custom preparator instance

31 tests passing.
- Add hash-based ID mapping (splitmix64) as an alternative to the dense `torch.unique` mapping in `build_sequences` and `align_embeddings`.
- Add `UniSRecModel.export_to_onnx()` for native ONNX export of the encoder and item embeddings (`project_all`).
- Add `UniSRecModel.map_item_ids()` for external→internal ID conversion at inference time (works for both dense and hash modes).
- Remove `FlatSASRecModel`/`FlatSASRecLightning` (RecTools-coupled wrappers that duplicated `UniSRecModel` functionality).
- Add tests: hash mapping (including string-derived IDs), ONNX export roundtrip, `map_item_ids` for both modes.
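For context, splitmix64 is a well-known 64-bit mixing function; a plain-Python sketch of its finalizer is below. The constants and shifts are the standard published ones — the exact implementation in this PR is not shown in the diff, so treat this as an illustration of the technique rather than the PR's code:

```python
MASK64 = (1 << 64) - 1

def splitmix64(x: int) -> int:
    """One splitmix64 step: add the golden-ratio constant, then bit-mix."""
    z = (x + 0x9E3779B97F4A7C15) & MASK64
    z = ((z ^ (z >> 30)) * 0xBF58476D1CE4E5B9) & MASK64
    z = ((z ^ (z >> 27)) * 0x94D049BB133111EB) & MASK64
    return (z ^ (z >> 31)) & MASK64

# Deterministic and stateless: the same external ID always maps to the same
# hash, so no dense lookup table has to be built, stored, or shipped.
assert splitmix64(42) == splitmix64(42)
assert splitmix64(42) != splitmix64(43)
assert 0 <= splitmix64(123456789) < 1 << 64
```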
Pull request overview
Adds a new rectools.fast_transformers subpackage providing GPU-native preprocessing and standalone sequential transformer recommenders (FlatSASRec + UniSRec), plus ranking utilities, scripts, and comprehensive tests.
Changes:
- Introduces torch-native sequence building (`build_sequences`), embedding alignment, and lightweight dataset/dataloader helpers.
- Adds UniSRec (pretrained text embeddings + adaptor + SASRec encoder) with a Lightning training wrapper and a standalone `UniSRecModel` API (fit/checkpoint/ONNX export).
- Adds `rank_topk()` for batched scoring with CSR filtering + whitelist, along with benchmark scripts and extensive test coverage.
Reviewed changes
Copilot reviewed 17 out of 19 changed files in this pull request and generated 9 comments.
| File | Description |
|---|---|
| rectools/fast_transformers/__init__.py | Exposes the new fast_transformers public API surface. |
| rectools/fast_transformers/gpu_data.py | Implements torch-native preprocessing utilities (sequence building, embedding alignment, dataloader helpers). |
| rectools/fast_transformers/net.py | Adds FlatSASRec network implementation. |
| rectools/fast_transformers/ranking.py | Adds rank_topk() batching + filtering + whitelist ranking utility. |
| rectools/fast_transformers/unisrec_lightning.py | Adds LightningModule wrapper (loss/optimizer/scheduler dispatch) for UniSRec training phases. |
| rectools/fast_transformers/unisrec_model.py | Adds standalone UniSRecModel (3-phase training, checkpointing, ONNX export, ID mapping). |
| rectools/fast_transformers/unisrec_net.py | Adds UniSRec network (adaptor + transformer encoder + helper methods). |
| tests/fast_transformers/__init__.py | Test package marker for fast_transformers. |
| tests/fast_transformers/test_gpu_data.py | Tests for sequence building, embedding alignment, dataset/dataloader, and hashing. |
| tests/fast_transformers/test_net.py | Tests for FlatSASRec forward paths and encoding helpers. |
| tests/fast_transformers/test_onnx_export.py | Tests ONNX export/roundtrip for UniSRec network and UniSRecModel export. |
| tests/fast_transformers/test_ranking.py | Tests top-k ranking, filtering, whitelist behavior, and edge cases. |
| tests/fast_transformers/test_unisrec_lightning.py | Tests UniSRecLightning configuration + loss/scheduler dispatch behavior. |
| tests/fast_transformers/test_unisrec_model.py | Tests UniSRecModel fit phases, losses/optimizers/schedulers, checkpointing, and mapping. |
| tests/fast_transformers/test_unisrec_net.py | Tests UniSRec network output shapes, adaptor variants, and freeze/unfreeze helpers. |
| scripts/compare_sasrec_unisrec.py | Benchmark script to compare RecTools SASRec vs UniSRec-ID and generate a report. |
| scripts/comparison_report.md | Adds a sample benchmark report output. |
| CHANGELOG.md | Documents the new module and features under Unreleased. |
| .gitignore | Ignores new dev artifacts, model weights, and data folders. |
```python
def build_sequences(
    user_ids: torch.Tensor,
    item_ids: torch.Tensor,
    timestamps: torch.Tensor,
    max_len: int,
    min_interactions: int = 2,
    device: str = "cuda",
    id_mapping: str = "dense",
) -> tp.Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    user_ids = user_ids.to(device)
    item_ids = item_ids.to(device)
    timestamps = timestamps.to(device)

    unique_items = torch.unique(item_ids)
    n_unique = len(unique_items)

    if id_mapping == "dense":
        _, item_inv = torch.unique(item_ids, return_inverse=True)
        internal_items = item_inv + 1
    elif id_mapping == "hash":
```
```python
x, y, unique_items, unique_users = build_sequences(
    user_ids,
    item_ids,
    timestamps,
    max_len=self.session_max_len,
    min_interactions=self.train_min_user_interactions,
    id_mapping=self.id_mapping,
)
self._unique_items = unique_items.cpu()
self._unique_users = unique_users.cpu()
n_items = len(unique_items)

aligned_emb = align_embeddings(self.pretrained_item_embeddings, unique_items, n_items, self.id_mapping)

net = UniSRec(
    n_items=n_items,
    pretrained_embeddings=aligned_emb,
    n_factors=self.n_factors,
    projection_hidden=self.projection_hidden,
    n_blocks=self.n_blocks,
    n_heads=self.n_heads,
    session_max_len=self.session_max_len,
    dropout=self.dropout,
    adaptor_dropout=self.adaptor_dropout,
    adaptor_type=self.adaptor_type,
    use_adaptor_ffn=self.use_adaptor_ffn,
    ffn_type=self.ffn_type,
    ffn_expansion=self.ffn_expansion,
)

train_dl = make_dataloader(x, y, batch_size=self.batch_size, shuffle=True)
```
```python
lookup = {int(v): i + 1 for i, v in enumerate(self._unique_items.tolist())}
return torch.tensor([lookup.get(int(x), 0) for x in external_ids.tolist()], dtype=torch.long)
```
```python
viewed_mask = torch.tensor(batch_csr.toarray(), dtype=torch.bool, device=device)
scores[viewed_mask] = -float("inf")
```
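The filtering idea above — overwrite already-viewed items with `-inf` before taking top-k — can be sketched without torch or scipy. The function and variable names here are illustrative, not the PR's `rank_topk` implementation:

```python
def topk_filtered(scores, viewed, k):
    """Indices of the k highest scores, excluding already-viewed items."""
    # Masked-out items get -inf, so they sort below every real score.
    masked = [float("-inf") if i in viewed else s for i, s in enumerate(scores)]
    order = sorted(range(len(masked)), key=lambda i: masked[i], reverse=True)
    return order[:k]

scores = [0.1, 0.9, 0.4, 0.8]
viewed = {1}  # item 1 was already interacted with, despite its top raw score
assert topk_filtered(scores, viewed, 2) == [3, 2]
```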
```python
def test_padding_invariance(self, net: FlatSASRec) -> None:
    """Different left-padding should produce same last-position embedding."""
    net.eval()
    # Same content should produce identical output
    x_a = torch.tensor([[0, 0, 0, 5, 10]])
    x_b = torch.tensor([[0, 0, 0, 5, 10]])
    with torch.no_grad():
        e_a = net.encode_last(x_a)
        e_b = net.encode_last(x_b)
    torch.testing.assert_close(e_a, e_b)
```
```python
class TestPaddingInvariance:
    def test_same_input_same_output(self, net: UniSRec) -> None:
        net.eval()
        x_a = torch.tensor([[0, 0, 0, 5, 10]])
        x_b = torch.tensor([[0, 0, 0, 5, 10]])
        with torch.no_grad():
            e_a = net.encode_last(x_a, use_id=False)
            e_b = net.encode_last(x_b, use_id=False)
        torch.testing.assert_close(e_a, e_b)
```
```python
train_dl = make_dataloader(x, y, batch_size=self.batch_size, shuffle=True)

val_dl = None
if self.patience is not None:
    val_y_last = y[:, -1:]
    val_dl = make_dataloader(x, val_y_last, batch_size=self.batch_size, shuffle=False)
```
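The `patience` option described in the PR amounts to standard early-stopping bookkeeping; Lightning's `EarlyStopping` callback does the equivalent of this framework-free sketch (class and parameter names here are illustrative):

```python
class EarlyStopper:
    """Stop when the validation loss hasn't improved for `patience` epochs."""

    def __init__(self, patience: int, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss: float) -> bool:
        if val_loss < self.best - self.min_delta:
            self.best = val_loss       # improvement: reset the counter
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1       # no improvement this epoch
        return self.bad_epochs >= self.patience  # True => stop training

stopper = EarlyStopper(patience=2)
losses = [1.0, 0.8, 0.9, 0.85]  # improves twice, then stalls for two epochs
stops = [stopper.step(l) for l in losses]
assert stops == [False, False, False, True]
```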
```python
x, y, unique_items, unique_users = build_sequences(
    user_ids,
    item_ids,
    timestamps,
    max_len=self.session_max_len,
    min_interactions=self.train_min_user_interactions,
    id_mapping=self.id_mapping,
)
```
```python
    max_len: int,
    min_interactions: int = 2,
    device: str = "cuda",
    id_mapping: str = "dense",
```
Better to use `Literal` for such things.
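A hedged sketch of what the reviewer suggests: constraining `id_mapping` (and similar string options) with `typing.Literal` lets type checkers reject unsupported values at call sites. The signature below mirrors the PR's `build_sequences` option but is purely illustrative:

```python
import typing as tp

IdMapping = tp.Literal["dense", "hash"]

def build_sequences_typed(id_mapping: IdMapping = "dense") -> str:
    # A checker like mypy flags build_sequences_typed(id_mapping="oops")
    # at type-check time; at runtime Literal is only an annotation, so a
    # defensive check is still useful for untyped callers.
    if id_mapping not in tp.get_args(IdMapping):
        raise ValueError(f"unknown id_mapping: {id_mapping!r}")
    return id_mapping

assert build_sequences_typed("hash") == "hash"
assert tp.get_args(IdMapping) == ("dense", "hash")
```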
```python
    min_interactions: int = 2,
    device: str = "cuda",
    id_mapping: str = "dense",
) -> tp.Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
```
Please add extensive docstrings for all the public methods, especially those meant to be used standalone. It's especially important here, since you're returning 4 tensors and the user doesn't understand their meaning. It would also be good to add examples.
```
catboost_info/

# Dev artifacts
training_folder/
```
`training_folder/` is a bit of a weird name — can we remove it?
```markdown
- `align_embeddings()` for mapping pretrained embedding matrices to internal item ID order
- `GPUBatchDataset` and `make_dataloader()` — lightweight torch Dataset/DataLoader wrappers for sequence training data
- Configurable FFN blocks in `UniSRec`: `conv1d` (original paper), `linear_gelu`, `linear_relu` with adjustable expansion factor
- Tests for all `fast_transformers` submodules (143 tests)
```
We normally don't add anything to the changelog that doesn't affect users directly, so there's not much sense in writing about the tests.
Please put this script and the report into a subfolder of the benchmark folder.
```python
    return aligned


class GPUBatchDataset(TorchDataset):
```
I'm not sure the name reflects the purpose:
- why GPU?
- what does Batch mean?

It also sounds quite "universal", even though I'd say it's more task-specific.
```python
    y: torch.Tensor,
    batch_size: int,
    shuffle: bool = True,
    transform: tp.Optional[tp.Callable] = None,
```
I'd recommend adding `**kwargs` here to cover the DataLoader's other parameters. On the other hand, I'm not sure it makes much sense to wrap 2 function calls in a separate function.
```python
from scipy import sparse


def rank_topk(
```
Sorry, I'm too lazy to check — could you please describe why we need this, given that we have `TorchRanker`? Could we reuse the code?
New `rectools.fast_transformers` module — standalone transformer sequential recommenders that work with raw torch tensors, without going through `Dataset`/pandas.

**GPU-native preprocessing.** `build_sequences()` builds left-padded interaction sequences entirely in torch (argsort + scatter). On ML-20M (20M interactions) this takes 0.5s vs 14.6s for the pandas-based `SASRecDataPreparator` — roughly 30x faster. For larger production data the problem is even worse: on the KION prod dataset, aggregating only half a year of interactions takes up to 50 minutes on the current RecTools code just to preprocess the data, while training takes a comparable time to finish.

**FlatSASRec.** Pre-norm SASRec encoder with plain ID embeddings, no ItemNet hierarchy. Wraps into `FlatSASRecModel` (inherits `ModelBase`), so it plugs into the standard RecTools fit/recommend.

**UniSRec.** Three-phase sequential recommender with pretrained text embeddings and a learnable PCA adaptor: `UniSRecModel.fit(user_ids, item_ids, timestamps)` takes raw tensors end-to-end. Supports softmax/BCE/gBCE/sampled_softmax losses, Adam/AdamW, a cosine warmup scheduler, gradient clipping, early stopping, and checkpoint save/load. FFN blocks are configurable (conv1d, linear_gelu, linear_relu).

`rank_topk()` — batched top-k with CSR viewed-item filtering and whitelist support.

**Benchmark (ML-20M, 10 epochs, softmax, Adam, n_factors=256).** UniSRec ID: +4.6% HR@10, +6.0% NDCG@10, 1.65x faster overall.

**New files**

Source (9 modules, 1683 lines):
- `rectools/fast_transformers/gpu_data.py` — `build_sequences`, `align_embeddings`, `GPUBatchDataset`, `make_dataloader`
- `rectools/fast_transformers/net.py` — `FlatSASRec`, `SASRecBlock`
- `rectools/fast_transformers/lightning_wrap.py` — `FlatSASRecLightning`
- `rectools/fast_transformers/model.py` — `FlatSASRecModel`, `FlatSASRecConfig`
- `rectools/fast_transformers/ranking.py` — `rank_topk`
- `rectools/fast_transformers/unisrec_net.py` — `UniSRec`, `FeedForward`, `make_ffn`
- `rectools/fast_transformers/unisrec_lightning.py` — `UniSRecLightning`, loss/optimizer/scheduler dispatch
- `rectools/fast_transformers/unisrec_model.py` — `UniSRecModel` (three-phase fit, checkpoint)

Tests (143 tests, 1920 lines):
- `tests/fast_transformers/test_gpu_data.py` — sequence building, alignment, dataset/dataloader
- `tests/fast_transformers/test_net.py`, `test_lightning_wrap.py`, `test_model.py` — FlatSASRec stack
- `tests/fast_transformers/test_unisrec_net.py`, `test_unisrec_lightning.py`, `test_unisrec_model.py` — UniSRec stack
- `tests/fast_transformers/test_ranking.py` — top-k, filtering, edge cases

Scripts:
- `scripts/compare_sasrec_unisrec.py` — full benchmark with markdown report generation
- `scripts/comparison_report.md` — benchmark results

**Test plan**
- `pytest tests/fast_transformers/ -q`
- `FlatSASRecModel` fit/recommend through the standard RecTools API on a small dataset
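The `cosine_warmup` scheduler mentioned above (linear warmup over a `warmup_ratio` fraction of the steps, then cosine decay down to `min_lr_ratio` of the base LR) has a standard formulation; here is a dependency-free sketch. The parameter names mirror the config options, but the exact implementation in the PR may differ:

```python
import math

def cosine_warmup_lr(step, total_steps, base_lr, warmup_ratio=0.1, min_lr_ratio=0.1):
    """LR at `step`: linear warmup, then cosine decay to min_lr_ratio * base_lr."""
    warmup_steps = max(1, int(total_steps * warmup_ratio))
    if step < warmup_steps:
        # Linear ramp from base_lr / warmup_steps up to base_lr.
        return base_lr * (step + 1) / warmup_steps
    # Cosine anneal from base_lr down to the floor min_lr.
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    min_lr = base_lr * min_lr_ratio
    return min_lr + (base_lr - min_lr) * cosine

total, base = 100, 1e-3
assert cosine_warmup_lr(9, total, base) == base                     # warmup peak
assert abs(cosine_warmup_lr(10, total, base) - base) < 1e-9         # decay starts at peak
assert abs(cosine_warmup_lr(99, total, base) - base * 0.1) < 1e-4   # near the floor
```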